Add Megatron-Bridge recipe-free distillation example script #861

kevalmorabia97 merged 6 commits into main
Conversation
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
📝 Walkthrough

The pull request extends the Megatron-Bridge examples with a comprehensive distillation workflow, including a new distill.py script for orchestrating student model distillation from teacher models, expanded documentation with end-to-end instructions, and minor enhancements to logging and utility scripts.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CLI as Command Line
    participant Main as main(args)
    participant HF as HuggingFace<br/>Checkpoints
    participant Bridge as AutoBridge<br/>Providers
    participant Distill as DistillationProvider
    participant Config as ConfigContainer
    participant Trainer as distill()
    CLI->>Main: Parse arguments (student/teacher HF paths, data, parallelism)
    Main->>HF: Load student & teacher checkpoints
    HF-->>Bridge: Return models
    Bridge->>Bridge: Build Megatron providers
    Bridge->>Bridge: Override parallelism & training settings
    Main->>Distill: Wrap providers with DistillationProvider
    Main->>Config: Assemble dataset, optimizer, scheduler,<br/>logging, checkpoint configs
    Config-->>Trainer: Pass ConfigContainer
    Main->>Trainer: Execute distill(config)
    Trainer->>Trainer: Create output/checkpoint directories
    Trainer->>Trainer: Run distributed training loop
    Trainer-->>Main: Report completion
    Main->>Main: Cleanup distributed environment
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```
@@ Coverage Diff @@
##             main     #861      +/-   ##
==========================================
- Coverage   73.72%   73.44%   -0.28%
==========================================
  Files         196      197       +1
  Lines       20457    20657     +200
==========================================
+ Hits        15082    15172      +90
- Misses       5375     5485     +110
```
```python
print_rank_0("\nStarting distillation...")
distill(config)
```
Should we make it like the Nemo one where it can do either pretrain(), distill(), or finetune() all in one file? (@ChenhanYu would that be preferred?)
How about we incrementally extend this file as we get to needing these options?
Maybe I should rename to train.py?
I guess right now we can easily just put a pretrain() call if the KD-specific args aren't provided.
SFT can be done later since dataset/template/etc is different.
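The fallback discussed above could be sketched as a small dispatch helper. This is a sketch only: `build_parser` and `select_mode` are hypothetical names, and treating `--teacher_hf_path` as the KD-specific argument is an assumption, not the merged implementation.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Minimal subset of the script's CLI, for illustration only.
    parser = argparse.ArgumentParser()
    parser.add_argument("--student_hf_path", required=True)
    parser.add_argument("--teacher_hf_path", default=None)
    return parser


def select_mode(args: argparse.Namespace) -> str:
    # With no teacher there is nothing to distill from,
    # so fall back to plain pretraining of the student.
    return "distill" if args.teacher_hf_path else "pretrain"


args = build_parser().parse_args(["--student_hf_path", "Qwen3-8B-NAS-Pruned-6B"])
print(select_mode(args))  # no teacher given -> pretrain
```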
```diff
-python /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help
+torchrun --nproc_per_node 1 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help
```
I want to print help only on rank 0, which requires initializing the process group, and that only happens under torchrun. I am not spawning the processes in the script itself, so running it with plain `python ...` errors while trying to find the RANK env variable during distributed setup.
M-Bridge has its own print_rank_0 which accounts for that.
We should additionally change our own print_rank_0 to work without dist initialized
Our print_rank_0 works fine in a non-distributed env. The issue here is that I am manually calling dist.setup(), which fails when not launched with torchrun. Since I am doing all the low-level M-Bridge work myself (because of the lack of top-level APIs), we don't get the distributed setup from M-Bridge.
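A rank-0 print that degrades gracefully without an initialized process group could look like the sketch below. It keys off the RANK env variable that torchrun exports for each worker; the actual helpers in modelopt and M-Bridge may be implemented differently (e.g. via `torch.distributed.is_initialized()`), so treat this as an assumption-laden illustration.

```python
import os


def print_rank_0(message: str) -> None:
    """Print only on rank 0, without requiring torch.distributed to be set up.

    torchrun exports RANK for every worker; when RANK is absent we assume
    a plain single-process `python ...` run and print unconditionally.
    """
    if int(os.environ.get("RANK", "0")) == 0:
        print(message)
```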
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/megatron_bridge/README.md (1)
Line 11: ⚠️ Potential issue | 🟡 Minor
Grammatical error: "distillation" → "distilling".
"Examples of distillation a pruned or quantized model" should read "Examples of distilling a pruned or quantized model".
Proposed fix:

```diff
-| Distillation | Examples of distillation a pruned or quantized model | \[[Link](`#distillation`)\] | |
+| Distillation | Examples of distilling a pruned or quantized model | \[[Link](`#distillation`)\] | |
```
🤖 Fix all issues with AI agents
In `@examples/megatron_bridge/distill.py`:
- Around line 120-122: Make --use_mock_data and --data_paths mutually exclusive
instead of silently letting mock data win: when building the CLI parser, create
a mutually exclusive group via parser.add_mutually_exclusive_group() and add the
two flags to that group (referencing args.use_mock_data and args.data_paths),
then remove the manual validation block that raises ValueError for
neither-provided; this ensures argparse enforces exclusivity and you can keep
the later code path that reads data_paths unchanged.
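The argparse change suggested above can be sketched in isolation. The flag names are taken from the review comment; the rest of the script's parser is omitted here.

```python
import argparse

parser = argparse.ArgumentParser()
# Exactly one data source must be chosen: argparse now rejects passing
# both flags together and errors out when neither is provided.
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--use_mock_data", action="store_true")
group.add_argument("--data_paths", nargs="+")

args = parser.parse_args(["--data_paths", "/data/part1", "/data/part2"])
print(args.data_paths)
```

Passing `--use_mock_data` together with `--data_paths` now exits with a usage error instead of silently letting mock data win.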
- Around line 132-134: The code computes checkpoint_dir and tensorboard_dir in
main(args: argparse.Namespace) but never ensures they exist; add explicit
directory creation before these paths are passed into
CheckpointConfig/LoggerConfig by calling os.makedirs(checkpoint_dir,
exist_ok=True) and os.makedirs(tensorboard_dir, exist_ok=True) (ensure imports
include os if not already) so the directories derived from args.output_dir are
created ahead of use.
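The directory-creation fix can be sketched as follows, with a temp dir standing in for `args.output_dir`; the `checkpoints`/`tensorboard` subdirectory names are assumptions based on the comment above.

```python
import os
import tempfile

output_dir = tempfile.mkdtemp()  # stands in for args.output_dir
checkpoint_dir = os.path.join(output_dir, "checkpoints")
tensorboard_dir = os.path.join(output_dir, "tensorboard")

# exist_ok=True keeps this idempotent, e.g. when resuming a run
# into an output directory that already exists.
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(tensorboard_dir, exist_ok=True)
print(os.path.isdir(checkpoint_dir), os.path.isdir(tensorboard_dir))
```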
🧹 Nitpick comments (3)

modelopt/torch/utils/plugins/megatron_preprocess_data.py (1)

Lines 113-116: Minor: `num2hrb` on small document counts displays decimals (e.g. "5.00 docs").
When `count` is small, `num2hrb` formats it as "5.00", which reads slightly oddly for a document count. This is cosmetic and doesn't affect functionality — just worth noting if you want polished early-iteration output.

examples/megatron_bridge/README.md (1)

Lines 36-38: Hardcoded Python 3.12 path in site-packages mount is fragile.
The volume mount path `/opt/venv/lib/python3.12/site-packages/modelopt` assumes the NeMo container uses Python 3.12. This will silently break if a future container version changes the Python version. Since you pin `nemo:26.02`, this is acceptable for now, but consider adding a comment noting the Python version dependency so future maintainers know to update this path.
Suggested comment:

```diff
   -v ${MODELOPT_DIR}:/opt/Model-Optimizer \
-  -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \
+  -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \  # Update python3.12 if container Python version changes
```

examples/megatron_bridge/distill.py (1)

Lines 162-168: Hardcoded `adam_beta2=0.98` — consider exposing as a CLI arg or documenting the choice.
`adam_beta2=0.98` differs from the common default of 0.999. While 0.98 is reasonable for distillation/pre-training, it's not configurable via the CLI. A comment explaining the choice would help users who want to tune this.
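Exposing the value as a CLI argument, as suggested, could look like the sketch below; the flag name and help text are assumptions, not part of the script.

```python
import argparse

parser = argparse.ArgumentParser()
# 0.98 matches the script's current hardcoded value; the usual Adam
# default is 0.999, so the deviation is documented in --help.
parser.add_argument(
    "--adam_beta2",
    type=float,
    default=0.98,
    help="Adam beta2 (default 0.98, lower than the usual 0.999; "
    "a common choice for distillation/pre-training stability).",
)

args = parser.parse_args([])
print(args.adam_beta2)
```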
```python
dataset_kwargs = {
    "seq_length": args.seq_length,
    "path_to_cache": args.data_path_to_cache,
    "random_seed": SEED,
```
can the seed also be a random arg?
You mean randomly generated every time? Then the results may not be reproducible.
```bash
  --ulimit memlock=-1 \
  --rm -it \
  -v ${MODELOPT_DIR}:/opt/Model-Optimizer \
  -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \
```
why is mounting to venv also necessary?
So users can mount the library and the examples from the same version. This avoids the case where a user runs an old modelopt install against examples from the main branch.
To convert the Megatron checkpoint from the last iteration (or any intermediate iteration) to Hugging Face format, you need the pruned model config (`--output_hf_path` from the `prune_minitron.py` script) and the distilled Megatron checkpoint dir (`<distill_output_dir>/checkpoints/iter_<iter_number>`), then run the following command:

```bash
uv run python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \
```
do we assume the user already has uv installed?
It's in the NeMo container, so it's already installed.
## What does this PR do?
**Type of change:** New example script <!-- Use one of the following:
Bug fix, new feature, new example, new tests, documentation. -->
- [x] M-Bridge recipe-free distillation script so it's easier to run and can support pruned models
- [x] Fix resuming distillation run
## Usage
<!-- You can potentially add a usage example below. -->
```bash
torchrun --nproc_per_node 8 distill.py \
--teacher_hf_path Qwen/Qwen3-8B \
--student_hf_path Qwen3-8B-NAS-Pruned-6B \
--tp_size 8 \
--data_paths <climbmix 25% tokenized (~90B tokens)> \
--data_path_to_cache /path/to/cache/climbmix_dataset_indices_qwen3 \
--seq_length 4096 \
--mbs 8 \
--gbs 768 \
--train_iters 28500 \
--lr 1e-4 \
--min_lr 1e-5 \
--lr_warmup_iters 100 \
--eval_interval 500 \
--eval_iters 32 \
--log_interval 10 \
--output_dir qwen3_8b_6b_mbridge_distill
```
## Testing
<!-- Mention how have you tested your change if applicable. -->
- [x] Re-ran Qwen3 8B -> 6B experiments and compared with NeMo2 results from the blog
Best subnet from NAS: `{'num_layers': 30, 'hidden_size': 3584,
'ffn_hidden_size': 11776} -> 5.99B params, 0.5718 score`
| Model | MMLU | GSM8K - flexible, strict | MBPP (coding) |
| ------- | ------ | ------- | ------- |
| Qwen3-8B | 74.9 | 87.5, 84.6 | 65.4 |
| Qwen3-8B-Pruned-6B | 57.6 | 11.6, 10.0 | 4.8 |
| Qwen3-8B-Pruned-6B (Distilled for 16k steps, i.e. 50B tokens, ~3k GPU hours) | 71.6 | 78.0, 64.7 | 43.4 |
| Qwen3-8B-Pruned-6B (Distilled for 28.5k steps, i.e. 90B tokens, ~5.2k GPU hours) | 71.9 | 78.1, 64.8 | 44.2 |
| Qwen3-4B | 70.0 | 81.1, 84.7 | 62.8 |
Previous NeMo2 experiments on depth-pruned Qwen3 8B -> 6B (24 layers) had MMLU ~72.0, so the results are more or less similar. No hyperparameter tuning was done for the current M-Bridge distillation run.
- [ ] (Separate PR) GitHub CI/CD test for example script with NeMo 26.02
container
## Before your PR is "*Ready for review*"
<!-- If you haven't finished some of the above items you can still open
`Draft` PR. -->
- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes <!--- If No, explain why.
-->
- **Did you write any new necessary tests?**: N/A
- **Did you add or update any necessary documentation?**: Yes
- **Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**:
Yes <!--- Only for new features, API changes, critical bug fixes or bw
breaking changes. -->
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
## Release Notes
* **New Features**
* Added complete distillation workflow and example for Megatron-Bridge
optimization.
* **Documentation**
* Enhanced setup guide with Docker workflows, data preparation steps,
and detailed distillation instructions.
* Improved usage documentation and help references.
* **Improvements**
* Better data preprocessing output with human-readable formatting for
metrics.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>